import torch
import pandas as pd
import numpy as np
from sklearn import preprocessing
from sklearn import gaussian_process
import matplotlib.pyplot as plt
import time
import sqlite3
torch.set_default_tensor_type('torch.cuda.DoubleTensor')
from MCMC.Utility import LoadData, get_inverse_block
import os

# Dataset name and outcome column; both are passed straight to LoadData.
DataName = 'Data1969_09'
OutcomeCol = 'sec_c'

X, A, Y = LoadData(DataName, OutcomeCol)

# Recode every level of A below 2 as level 2, as there are very few such
# observations.
A[A < 2] = 2

# obs_num[k] is the number of observations with A == k + 2.
obs_num = torch.bincount(A)[2:]

# Tail counts (A >= 2, A >= 3, A >= 4, A >= 5).  They are stored as plain
# Python ints rather than 0-dim tensors because they are only ever used as
# sizes, slice bounds and range limits further down.
# NOTE(review): the later `[:n23]`-style slicing presumes LoadData returns the
# rows ordered so that these observations come first -- confirm.
n2 = int(obs_num.sum())
n23 = int(obs_num[1:].sum())
n34 = int(obs_num[2:].sum())
n45 = int(obs_num[3:].sum())


# Kernel hyper-parameters and the scales of the latent and noise terms.
Matern_length_scale = 1.5
Matern_nu = 1.5
sig_latent = 2
sig_noise = 1

# Correlation matrix of all rows of X under the Matern kernel.
Kernel = gaussian_process.kernels.Matern(length_scale=Matern_length_scale, nu=Matern_nu)
Corr_mat = torch.Tensor(Kernel(X.cpu()))


def _scaled_cov(corr, size):
    """Scale a correlation block by sig_latent**2 and add 0.01 to its diagonal."""
    return sig_latent**2 * corr + torch.eye(n=size)*0.01


def _stacked_inverse_blocks(cov, size):
    """Call get_inverse_block for every index and stack its two outputs.

    Returns a two-element list: the stacked first outputs and the stacked
    second outputs, each with one entry per index.
    """
    per_index = [get_inverse_block(idx, cov, size) for idx in range(size)]
    firsts = torch.stack([pair[0] for pair in per_index])
    seconds = torch.stack([pair[1] for pair in per_index])
    return [firsts, seconds]


# Covariance of the full vector and of each leading sub-block.
Cov_mat_F = _scaled_cov(Corr_mat, n2)
Cov_mat_23 = _scaled_cov(Corr_mat[:n23, :n23], n23)
Cov_mat_34 = _scaled_cov(Corr_mat[:n34, :n34], n34)
Cov_mat_45 = _scaled_cov(Corr_mat[:n45, :n45], n45)

# Per-index quantities from get_inverse_block; the sampler later uses entry 0
# to form a conditional mean and entry 1 as a conditional variance.
temp_F = _stacked_inverse_blocks(Cov_mat_F, n2)
temp_23 = _stacked_inverse_blocks(Cov_mat_23, n23)
temp_34 = _stacked_inverse_blocks(Cov_mat_34, n34)
temp_45 = _stacked_inverse_blocks(Cov_mat_45, n45)

# From here on new tensors default to CPU doubles.
mytensortype = 'torch.DoubleTensor'
torch.set_default_tensor_type(mytensortype)

# Paramatrix stacks the five sampled components along dim 0, in the order
# Epsilon, F, T23, T34, T45.  Every name below is a view into it, so in-place
# writes through the views update Paramatrix as well.
Paramatrix = torch.zeros(size=(5, n2, 1))
Epsilon0, F0, T230, T340, T450 = (Paramatrix[row, :, :] for row in range(5))

Epsilon = Epsilon0
F = F0
T23 = T230[:n23, :]
T34 = T340[:n34, :]
T45 = T450[:n45, :]

# Initialise each component uniformly on [-0.5, 0.5).  The draw order
# (F, T23, T34, T45, Epsilon) is kept so the random stream is consumed in the
# same sequence.
for _view, _rows in ((F, n2), (T23, n23), (T34, n34), (T45, n45), (Epsilon, n2)):
    _view[:] = torch.rand(size=(_rows, 1)) - 0.5

# When sampling on the CPU, bring the precomputed blocks and the outcomes
# over to the CPU as well.
if mytensortype == 'torch.DoubleTensor':
    temp_F = [part.cpu() for part in temp_F]
    temp_23 = [part.cpu() for part in temp_23]
    temp_34 = [part.cpu() for part in temp_34]
    temp_45 = [part.cpu() for part in temp_45]
    Y = Y.cpu()

def _truncation_bounds(positive, cur_sum):
    """Return the (lower, upper) truncation limits for one component draw.

    ``-cur_sum`` is the value at which the component being drawn plus the
    remaining components (which sum to ``cur_sum``) equals zero.  For a
    positive outcome the draw is kept at or above that value and capped at 10;
    otherwise it is kept at or below it and floored at -10.
    """
    threshold = -float(cur_sum)
    if positive:
        return threshold, 10
    return -10, threshold


def one_step(i):
    """Resample, in place, every component of observation ``i``.

    The components are the rows of the module-level ``Paramatrix`` (Epsilon,
    F, T23, T34, T45) and are updated in that order through the module-level
    views of the same names.  Each draw uses ``torch.nn.init.trunc_normal_``
    with limits chosen so that the sum of all components at index ``i`` stays
    non-negative when ``Y[i] > 0.5`` and non-positive otherwise.

    Epsilon is drawn with mean 0 and standard deviation ``sig_noise``.  For
    the other components, ``temp_X[0][i]`` is multiplied with the remaining
    entries of the component to give the mean and ``temp_X[1][i]`` is used as
    the variance.  T23, T34 and T45 are only updated while ``i`` lies inside
    their first ``n23`` / ``n34`` / ``n45`` entries.

    Args:
        i: index of the observation to update.

    Returns:
        None; ``Paramatrix`` is modified in place.
    """
    positive = bool(Y[i] > 0.5)

    # sample epsilon
    cur_sum = Paramatrix[1:, i, :].sum()
    lower, upper = _truncation_bounds(positive, cur_sum)
    torch.nn.init.trunc_normal_(Epsilon[i], mean=0, std=sig_noise, a=lower, b=upper)

    # (row in Paramatrix, view to update, precomputed blocks, number of valid
    # indices or None when every index is valid).  The tuple order fixes the
    # sampling order: F, T23, T34, T45.
    latent_blocks = (
        (1, F, temp_F, None),
        (2, T23, temp_23, n23),
        (3, T34, temp_34, n34),
        (4, T45, temp_45, n45),
    )
    for row, values, stats, limit in latent_blocks:
        if limit is not None and not 0 <= i < limit:
            continue
        # Sum of every component at index i except the one being drawn.
        cur_sum = Paramatrix[:row, i, :].sum() + Paramatrix[row + 1:, i, :].sum()
        # Mean from all entries of this component other than entry i.
        cond_mean = stats[0][i, :] @ torch.cat([values[:i], values[(i + 1):]])
        cond_var = stats[1][i]
        lower, upper = _truncation_bounds(positive, cur_sum)
        torch.nn.init.trunc_normal_(values[i], mean=cond_mean, std=cond_var**0.5, a=lower, b=upper)

# Gibbs sampling: R full sweeps over all n2 indices.  After each sweep the
# current state of every component is stored as one row of its result tensor.
R = 4000
result = [torch.empty(size=(R, n_cols)) for n_cols in (n2, n2, n23, n34, n45)]
usage = 0
for rep in range(1, R + 1):
    for i in range(n2):
        one_step(i)

    # Row rep - 1 holds the state after sweep number rep.
    for store, state in zip(result, (Epsilon, F, T23, T34, T45)):
        store[rep - 1, :] = state.squeeze()

    print(rep)


# Switch the default tensor type back to CUDA doubles and move the samples
# over to the GPU.
torch.set_default_tensor_type('torch.cuda.DoubleTensor')

for i, samples in enumerate(result):
    result[i] = samples.cuda()

def Extrapolation(observed):
    """Draw the trailing entries of the vector given its leading entries.

    The module-level ``Cov_mat_F`` is split at ``n = observed.shape[1]`` into
    the block of the first ``n`` indices (Sig22), the block of the remaining
    indices (Sig11) and their cross block (Sig12).  For every row of
    ``observed`` one sample is drawn from a normal distribution with

        mean = Sig12 @ Sig22^{-1} @ row
        cov  = Sig11 - Sig12 @ Sig22^{-1} @ Sig12.T

    Args:
        observed: R x n matrix; each row holds values for the first ``n``
            indices.  Must share dtype and device with ``Cov_mat_F``.

    Returns:
        R x (N - n) matrix of draws, where N is the size of ``Cov_mat_F``,
        with the dtype and device of the inputs.
    """
    n = observed.shape[1]
    if n >= Cov_mat_F.shape[0]:
        # Every index is already observed: nothing to draw, so skip the
        # linear algebra on empty blocks.
        return observed.new_empty((observed.shape[0], 0))

    Sig22 = Cov_mat_F[:n, :n]
    Sig12 = Cov_mat_F[n:, :n]
    Sig11 = Cov_mat_F[n:, n:]

    # weights = Sig12 @ Sig22^{-1}, obtained from one linear solve instead of
    # forming the inverse explicitly.  The transpose trick relies on Sig22
    # being symmetric, which holds for a covariance matrix.
    weights = torch.linalg.solve(Sig22, Sig12.T).T
    cond_mean = weights @ observed.T
    cond_cov = Sig11 - weights @ Sig12.T

    # Create the noise with the dtype/device of the data instead of relying on
    # whatever the global default tensor type happens to be at call time.
    noise = torch.randn(cond_mean.shape, dtype=cond_mean.dtype, device=cond_mean.device)
    gen_data = torch.linalg.cholesky(cond_cov) @ noise + cond_mean
    return gen_data.T

# Extrapolation() combines Cov_mat_F with the CUDA-resident samples, so the
# covariance has to live on the GPU as well.
Cov_mat_F = Cov_mat_F.cuda()

# Extrapolation: extend each shorter component with draws for its remaining
# indices so that all five components have n2 columns and can be stacked.
Extrap_23 = torch.cat([result[2], Extrapolation(result[2])], dim=1)
Extrap_34 = torch.cat([result[3], Extrapolation(result[3])], dim=1)
Extrap_45 = torch.cat([result[4], Extrapolation(result[4])], dim=1)

# Stacked in the order Epsilon, F, T23, T34, T45.
TotalResult = torch.stack([result[0], result[1], Extrap_23, Extrap_34, Extrap_45])


# File name encodes the data set, outcome, number of sweeps and the four
# hyper-parameters (each multiplied by 100 and truncated to an int).
scaled_params = [
    str(int(100*value))
    for value in (Matern_length_scale, Matern_nu, sig_latent, sig_noise)
]
myfilename = '_'.join([DataName, OutcomeCol, 'TotalResult', str(R)] + scaled_params) + '.pt'

# Create the output directory first: without it torch.save raises and the
# whole sampling run is lost.
output_dir = 'Database/MCMCSample'
os.makedirs(output_dir, exist_ok=True)
torch.save(TotalResult, os.path.join(output_dir, myfilename))
